import numpy as np
from tqdm import tqdm
from jax import random
from src.datasets.dataset_creator.utils import pad_grid
from typing import List
import jax
import jax.numpy as jnp

import numpy as np
from jax import random
from tqdm import tqdm


def create_data(
    B: int,
    N: int,
    seed: int,
    transformations: List,
    program_base_index: int = 0,
    verbose: bool = False,
    max_attempts: int = 100,
    max_sample_attempts: int = 10,
) -> tuple:
    """
    Generate a dataset of object transformations.

    For each of the ``B`` batches one transformation is sampled (weighted by its
    ``"frequency_weight"``) and applied to ``N`` freshly generated inputs. All ``N`` pairs of a
    batch are produced with the same program key, so a program that draws randomness from its
    key behaves consistently across the batch. Samples whose output equals their input are
    rejected and regenerated.

    Args:
        B: Number of batches to generate.
        N: Number of samples (input-output pairs) per batch.
        seed: Seed for the JAX PRNG used to pick transformations and program keys.
        transformations: List of transformation dictionaries. Each must provide
            ``"input_generator"`` (zero-arg callable returning ``(input_grid, extras)``),
            ``"program"`` (callable invoked as ``program((input_grid, extras), key)``),
            ``"name"`` and ``"frequency_weight"``.
        program_base_index: Offset added to every returned program ID.
        verbose: If True, print a message for every sample rejected because input == output.
        max_attempts: Number of transformations to try per batch before giving up.
        max_sample_attempts: Generation attempts per sample before abandoning the current
            transformation and sampling a new one.

    Returns:
        data: int32 array of shape (B, N, 30, 30, 2); ``[..., 0]`` holds the zero-padded input
            grid and ``[..., 1]`` the zero-padded output grid.
        shapes: int32 array of shape (B, N, 2, 2) laid out as
            ``[[input_rows, output_rows], [input_cols, output_cols]]``.
        program_ids: int32 array of shape (B,) holding the sampled transformation index plus
            ``program_base_index``.

    Raises:
        ValueError: If no valid batch is produced within ``max_attempts`` attempts.
    """
    data = np.zeros((B, N, 30, 30, 2), dtype=np.int32)
    grid_shapes = np.zeros((B, N, 2, 2), dtype=np.int32)
    program_ids = np.zeros(B, dtype=np.int32)

    key = random.PRNGKey(seed)

    for b in tqdm(range(B), desc="Generating batches"):
        for _ in range(max_attempts):
            # Derive independent keys for choosing the transformation and for running the
            # program; the carried ``key`` is only ever split, never consumed directly.
            key, transform_key, program_key = random.split(key, 3)
            transformation, idx = sample_transformation(transform_key, transformations)

            batch = _generate_batch(transformation, program_key, N, max_sample_attempts, verbose)
            if batch is not None:
                data[b], grid_shapes[b] = batch
                program_ids[b] = idx
                break  # All N samples generated; move on to the next batch.

            print(
                f"Failed to generate valid batch with transformation {transformation['name']}. Retrying with a new transformation."
            )
        else:
            # Loop exhausted without a successful batch.
            raise ValueError(f"Failed to generate valid batch after {max_attempts} attempts for batch {b}")

    return data, grid_shapes, program_ids + program_base_index


def _generate_batch(transformation, program_key, num_samples, max_sample_attempts, verbose):
    """Generate ``num_samples`` non-identity pairs for one transformation, or None on failure.

    Returns ``(batch_data, batch_shapes)`` with shapes (num_samples, 30, 30, 2) and
    (num_samples, 2, 2), or ``None`` if any sample could not be produced within
    ``max_sample_attempts`` tries. Exceptions raised while generating a sample are printed and
    count as a failed attempt.
    """
    input_generator = transformation["input_generator"]
    program = transformation["program"]

    batch_data = np.zeros((num_samples, 30, 30, 2), dtype=np.int32)
    batch_shapes = np.zeros((num_samples, 2, 2), dtype=np.int32)

    for i in range(num_samples):
        for _ in range(max_sample_attempts):
            try:
                input_grid, extras = input_generator()
                output_grid = program((input_grid, extras), program_key)

                if np.array_equal(input_grid, output_grid):
                    # Reject samples the program left unchanged and try again.
                    if verbose:
                        print(
                            "Input and output grids are equal. Skipping sample: {}".format(
                                transformation["name"]
                            )
                        )
                    continue

                batch_data[i, :, :, 0] = pad_grid(input_grid, (30, 30))
                batch_data[i, :, :, 1] = pad_grid(output_grid, (30, 30))
                batch_shapes[i] = [
                    [input_grid.shape[0], output_grid.shape[0]],
                    [input_grid.shape[1], output_grid.shape[1]],
                ]
                break  # Sample i is valid.
            except Exception as e:
                print(f"Error in sample generation: {e}")
        else:
            # Every attempt for sample i failed: give up on this transformation.
            return None

    return batch_data, batch_shapes


def pad_grid(grid, target_shape):
    """
    Zero-pad ``grid`` at the end of every axis up to ``target_shape``.

    The original content occupies the leading (top-left) corner of the result, and the result
    keeps ``grid``'s dtype.

    Args:
        grid: Array-like to pad; must have as many dimensions as ``target_shape`` has entries.
        target_shape: Desired output shape; every entry must be >= the matching ``grid`` dim.

    Returns:
        A new ``np.ndarray`` of shape ``target_shape``.

    Raises:
        ValueError: If the ranks differ or ``grid`` is larger than ``target_shape`` on any axis.
    """
    grid = np.asarray(grid)
    target_shape = tuple(target_shape)
    if grid.ndim != len(target_shape):
        raise ValueError(f"grid has {grid.ndim} dims but target_shape has {len(target_shape)}")
    if any(g > t for g, t in zip(grid.shape, target_shape)):
        raise ValueError(f"grid shape {grid.shape} exceeds target shape {target_shape}")

    padded = np.zeros(target_shape, dtype=grid.dtype)
    padded[tuple(slice(0, s) for s in grid.shape)] = grid
    return padded


def sample_transformation(key, transformations):
    """
    Sample one transformation with probability proportional to its ``"frequency_weight"``.

    Args:
        key: JAX PRNG key used for the draw.
        transformations: Non-empty sequence of dicts, each with a ``"frequency_weight"`` entry.

    Returns:
        ``(transformation, idx)`` where ``idx`` is the sampled index as returned by
        ``jax.random.choice``.

    Raises:
        ValueError: If ``transformations`` is empty, any weight is negative, or the weights do
            not sum to a positive finite number (which would yield NaN probabilities).
    """
    if not transformations:
        raise ValueError("transformations must be non-empty")

    weights = np.asarray([t["frequency_weight"] for t in transformations], dtype=np.float64)
    if np.any(weights < 0):
        raise ValueError("frequency_weight values must be non-negative")
    total = weights.sum()
    if not np.isfinite(total) or total <= 0:
        raise ValueError("frequency_weight values must sum to a positive finite number")

    idx = random.choice(key, len(transformations), p=weights / total)
    return transformations[idx], idx


def create_dataset(B_train, B_test, N, seed, baseline_program_id, transformations, verbose=False):
    """
    Generate a train split and a test split with ``create_data``.

    Args:
        B_train: Number of batches in the training split.
        B_test: Number of batches in the test split (may be 0).
        N: Number of input-output pairs per batch.
        seed: Seed for NumPy's global RNG and the JAX PRNG; the test split uses ``seed + 1``.
        baseline_program_id: Offset added to every program ID in both splits.
        transformations: List of transformation dictionaries passed to ``create_data``.
        verbose: Forwarded to ``create_data``.

    Returns:
        ``((data_train, shapes_train, ids_train), (data_test, shapes_test, ids_test))``.

    Raises:
        ValueError: If a non-empty split contains grid values outside ``[0, 9]``.
    """
    # Seed NumPy's global RNG; presumably the input generators draw from np.random — TODO confirm.
    np.random.seed(seed)

    # Train and test use different JAX seeds so their sampled keys differ.
    train = create_data(B_train, N, seed, transformations, baseline_program_id, verbose=verbose)
    test = create_data(B_test, N, seed + 1, transformations, baseline_program_id, verbose=verbose)

    for split_name, (dataset, _, _) in (("train", train), ("test", test)):
        # np.min/np.max raise on zero-size arrays, so empty splits are skipped.
        if dataset.size == 0:
            continue
        lo, hi = dataset.min(), dataset.max()
        if lo < 0 or hi > 9:
            raise ValueError(
                f"{split_name} dataset contains values outside [0, 9] (min={lo}, max={hi})"
            )

    return train, test
